{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#import SparkSession\n",
    "from pyspark.sql import SparkSession\n",
    "spark=SparkSession.builder.appName('random_forest').getOrCreate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#read the dataset\n",
    "df=spark.read.csv('affairs.csv',inferSchema=True,header=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(6366, 6)\n"
     ]
    }
   ],
   "source": [
    "#check the shape of the data \n",
    "print((df.count(),len(df.columns)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- rate_marriage: integer (nullable = true)\n",
      " |-- age: double (nullable = true)\n",
      " |-- yrs_married: double (nullable = true)\n",
      " |-- children: double (nullable = true)\n",
      " |-- religious: integer (nullable = true)\n",
      " |-- affairs: integer (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "#printSchema\n",
    "df.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------------+----+-----------+--------+---------+-------+\n",
      "|rate_marriage| age|yrs_married|children|religious|affairs|\n",
      "+-------------+----+-----------+--------+---------+-------+\n",
      "|            5|32.0|        6.0|     1.0|        3|      0|\n",
      "|            4|22.0|        2.5|     0.0|        2|      0|\n",
      "|            3|32.0|        9.0|     3.0|        3|      1|\n",
      "|            3|27.0|       13.0|     3.0|        1|      1|\n",
      "|            4|22.0|        2.5|     0.0|        1|      1|\n",
      "+-------------+----+-----------+--------+---------+-------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "#view the dataset\n",
    "df.show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------+------------------+------------------+-----------------+------------------+------------------+\n",
      "|summary|     rate_marriage|               age|      yrs_married|          children|         religious|\n",
      "+-------+------------------+------------------+-----------------+------------------+------------------+\n",
      "|  count|              6366|              6366|             6366|              6366|              6366|\n",
      "|   mean| 4.109644989004084|29.082862079798932| 9.00942507068803|1.3968740182218033|2.4261702796104303|\n",
      "| stddev|0.9614295945655025| 6.847881883668817|7.280119972766412| 1.433470828560344|0.8783688402641785|\n",
      "|    min|                 1|              17.5|              0.5|               0.0|                 1|\n",
      "|    max|                 5|              42.0|             23.0|               5.5|                 4|\n",
      "+-------+------------------+------------------+-----------------+------------------+------------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "#Exploratory Data Analysis\n",
    "df.describe().select('summary','rate_marriage','age','yrs_married','children','religious').show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------+-----+\n",
      "|affairs|count|\n",
      "+-------+-----+\n",
      "|      1| 2053|\n",
      "|      0| 4313|\n",
      "+-------+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy('affairs').count().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------------+-----+\n",
      "|rate_marriage|count|\n",
      "+-------------+-----+\n",
      "|            1|   99|\n",
      "|            3|  993|\n",
      "|            5| 2684|\n",
      "|            4| 2242|\n",
      "|            2|  348|\n",
      "+-------------+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy('rate_marriage').count().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------------+-------+-----+\n",
      "|rate_marriage|affairs|count|\n",
      "+-------------+-------+-----+\n",
      "|            1|      0|   25|\n",
      "|            1|      1|   74|\n",
      "|            2|      0|  127|\n",
      "|            2|      1|  221|\n",
      "|            3|      0|  446|\n",
      "|            3|      1|  547|\n",
      "|            4|      0| 1518|\n",
      "|            4|      1|  724|\n",
      "|            5|      0| 2197|\n",
      "|            5|      1|  487|\n",
      "+-------------+-------+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy('rate_marriage','affairs').count().orderBy('rate_marriage','affairs','count',ascending=True).show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---------+-------+-----+\n",
      "|religious|affairs|count|\n",
      "+---------+-------+-----+\n",
      "|        1|      0|  613|\n",
      "|        1|      1|  408|\n",
      "|        2|      0| 1448|\n",
      "|        2|      1|  819|\n",
      "|        3|      0| 1715|\n",
      "|        3|      1|  707|\n",
      "|        4|      0|  537|\n",
      "|        4|      1|  119|\n",
      "+---------+-------+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy('religious','affairs').count().orderBy('religious','affairs','count',ascending=True).show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------+-------+-----+\n",
      "|children|affairs|count|\n",
      "+--------+-------+-----+\n",
      "|     0.0|      0| 1912|\n",
      "|     0.0|      1|  502|\n",
      "|     1.0|      0|  747|\n",
      "|     1.0|      1|  412|\n",
      "|     2.0|      0|  873|\n",
      "|     2.0|      1|  608|\n",
      "|     3.0|      0|  460|\n",
      "|     3.0|      1|  321|\n",
      "|     4.0|      0|  197|\n",
      "|     4.0|      1|  131|\n",
      "|     5.5|      0|  124|\n",
      "|     5.5|      1|   79|\n",
      "+--------+-------+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy('children','affairs').count().orderBy('children','affairs','count',ascending=True).show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------+------------------+------------------+------------------+------------------+------------------+------------+\n",
      "|affairs|avg(rate_marriage)|          avg(age)|  avg(yrs_married)|     avg(children)|    avg(religious)|avg(affairs)|\n",
      "+-------+------------------+------------------+------------------+------------------+------------------+------------+\n",
      "|      1|3.6473453482708234|30.537018996590355|11.152459814905017|1.7289332683877252| 2.261568436434486|         1.0|\n",
      "|      0| 4.329700904242986| 28.39067934152562| 7.989334569904939|1.2388128912589844|2.5045212149316023|         0.0|\n",
      "+-------+------------------+------------------+------------------+------------------+------------------+------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy('affairs').mean().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import VectorAssembler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_assembler = VectorAssembler(inputCols=['rate_marriage', 'age', 'yrs_married', 'children', 'religious'], outputCol=\"features\")\n",
    "df = df_assembler.transform(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- rate_marriage: integer (nullable = true)\n",
      " |-- age: double (nullable = true)\n",
      " |-- yrs_married: double (nullable = true)\n",
      " |-- children: double (nullable = true)\n",
      " |-- religious: integer (nullable = true)\n",
      " |-- affairs: integer (nullable = true)\n",
      " |-- features: vector (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------------------+-------+\n",
      "|features               |affairs|\n",
      "+-----------------------+-------+\n",
      "|[5.0,32.0,6.0,1.0,3.0] |0      |\n",
      "|[4.0,22.0,2.5,0.0,2.0] |0      |\n",
      "|[3.0,32.0,9.0,3.0,3.0] |1      |\n",
      "|[3.0,27.0,13.0,3.0,1.0]|1      |\n",
      "|[4.0,22.0,2.5,0.0,1.0] |1      |\n",
      "|[4.0,37.0,16.5,4.0,3.0]|1      |\n",
      "|[5.0,27.0,9.0,1.0,1.0] |1      |\n",
      "|[4.0,27.0,9.0,0.0,2.0] |1      |\n",
      "|[5.0,37.0,23.0,5.5,2.0]|1      |\n",
      "|[5.0,37.0,23.0,5.5,2.0]|1      |\n",
      "+-----------------------+-------+\n",
      "only showing top 10 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.select(['features','affairs']).show(10,False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "#select data for building model\n",
    "model_df=df.select(['features','affairs'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df,test_df=model_df.randomSplit([0.75,0.25])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4800"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_df.count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------+-----+\n",
      "|affairs|count|\n",
      "+-------+-----+\n",
      "|      1| 1574|\n",
      "|      0| 3226|\n",
      "+-------+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "train_df.groupBy('affairs').count().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------+-----+\n",
      "|affairs|count|\n",
      "+-------+-----+\n",
      "|      1|  479|\n",
      "|      0| 1087|\n",
      "+-------+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "test_df.groupBy('affairs').count().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.classification import RandomForestClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "rf_classifier=RandomForestClassifier(labelCol='affairs',numTrees=50).fit(train_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "rf_predictions=rf_classifier.transform(test_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+-------+--------------------+--------------------+----------+\n",
      "|            features|affairs|       rawPrediction|         probability|prediction|\n",
      "+--------------------+-------+--------------------+--------------------+----------+\n",
      "|[1.0,22.0,2.5,1.0...|      1|[18.7524598757082...|[0.37504919751416...|       1.0|\n",
      "|[1.0,27.0,2.5,0.0...|      1|[21.0021478623554...|[0.42004295724710...|       1.0|\n",
      "|[1.0,27.0,6.0,0.0...|      0|[20.1166611727778...|[0.40233322345555...|       1.0|\n",
      "|[1.0,27.0,6.0,2.0...|      1|[19.0905206218080...|[0.38181041243616...|       1.0|\n",
      "|[1.0,27.0,6.0,3.0...|      0|[16.3348579592130...|[0.32669715918426...|       1.0|\n",
      "|[1.0,27.0,9.0,4.0...|      0|[13.3128973003485...|[0.26625794600697...|       1.0|\n",
      "|[1.0,32.0,13.0,0....|      1|[16.7812008990910...|[0.33562401798182...|       1.0|\n",
      "|[1.0,32.0,13.0,2....|      1|[12.6966189294366...|[0.25393237858873...|       1.0|\n",
      "|[1.0,32.0,13.0,2....|      1|[12.6766379097881...|[0.25353275819576...|       1.0|\n",
      "|[1.0,32.0,16.5,2....|      1|[16.5966383870776...|[0.33193276774155...|       1.0|\n",
      "|[1.0,32.0,16.5,2....|      1|[16.5966383870776...|[0.33193276774155...|       1.0|\n",
      "|[1.0,32.0,16.5,3....|      1|[18.0962335004510...|[0.36192467000902...|       1.0|\n",
      "|[1.0,37.0,16.5,1....|      1|[17.1671412215484...|[0.34334282443096...|       1.0|\n",
      "|[1.0,37.0,16.5,2....|      1|[16.9033235593205...|[0.33806647118641...|       1.0|\n",
      "|[1.0,37.0,23.0,4....|      1|[12.9328949963349...|[0.25865789992669...|       1.0|\n",
      "|[1.0,37.0,23.0,5....|      1|[19.2882486885513...|[0.38576497377102...|       1.0|\n",
      "|[1.0,42.0,16.5,5....|      1|[14.7895576971207...|[0.29579115394241...|       1.0|\n",
      "|[1.0,42.0,23.0,2....|      1|[13.1159769754090...|[0.26231953950818...|       1.0|\n",
      "|[1.0,42.0,23.0,2....|      1|[16.8901830088537...|[0.33780366017707...|       1.0|\n",
      "|[1.0,42.0,23.0,2....|      1|[16.8901830088537...|[0.33780366017707...|       1.0|\n",
      "+--------------------+-------+--------------------+--------------------+----------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "rf_predictions.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------+-----+\n",
      "|prediction|count|\n",
      "+----------+-----+\n",
      "|       0.0| 1261|\n",
      "|       1.0|  305|\n",
      "+----------+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "rf_predictions.groupBy('prediction').count().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------------------------------+-------+----------+\n",
      "|probability                             |affairs|prediction|\n",
      "+----------------------------------------+-------+----------+\n",
      "|[0.37504919751416455,0.6249508024858356]|1      |1.0       |\n",
      "|[0.42004295724710805,0.579957042752892] |1      |1.0       |\n",
      "|[0.40233322345555694,0.597666776544443] |0      |1.0       |\n",
      "|[0.3818104124361619,0.6181895875638381] |1      |1.0       |\n",
      "|[0.32669715918426007,0.6733028408157399]|0      |1.0       |\n",
      "|[0.26625794600697006,0.7337420539930299]|0      |1.0       |\n",
      "|[0.3356240179818214,0.6643759820181787] |1      |1.0       |\n",
      "|[0.25393237858873335,0.7460676214112667]|1      |1.0       |\n",
      "|[0.2535327581957624,0.7464672418042376] |1      |1.0       |\n",
      "|[0.3319327677415531,0.6680672322584469] |1      |1.0       |\n",
      "+----------------------------------------+-------+----------+\n",
      "only showing top 10 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "rf_predictions.select(['probability','affairs','prediction']).show(10,False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.evaluation import BinaryClassificationEvaluator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.evaluation import MulticlassClassificationEvaluator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "rf_accuracy=MulticlassClassificationEvaluator(labelCol='affairs',metricName='accuracy').evaluate(rf_predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The accuracy of RF on test data is 73%\n"
     ]
    }
   ],
   "source": [
    "print('The accuracy of RF on test data is {0:.0%}'.format(rf_accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7279693486590039\n"
     ]
    }
   ],
   "source": [
    "print(rf_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "rf_precision=MulticlassClassificationEvaluator(labelCol='affairs',metricName='weightedPrecision').evaluate(rf_predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The precision rate on test data is 71%\n"
     ]
    }
   ],
   "source": [
    "print('The precision rate on test data is {0:.0%}'.format(rf_precision))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7085017563673452"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rf_precision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "rf_auc=BinaryClassificationEvaluator(labelCol='affairs').evaluate(rf_predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7421702296835062\n"
     ]
    }
   ],
   "source": [
    "print(rf_auc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature importance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SparseVector(5, {0: 0.5536, 1: 0.0364, 2: 0.2358, 3: 0.0803, 4: 0.0939})"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rf_classifier.featureImportances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'numeric': [{'idx': 0, 'name': 'rate_marriage'},\n",
       "  {'idx': 1, 'name': 'age'},\n",
       "  {'idx': 2, 'name': 'yrs_married'},\n",
       "  {'idx': 3, 'name': 'children'},\n",
       "  {'idx': 4, 'name': 'religious'}]}"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.schema[\"features\"].metadata[\"ml_attr\"][\"attrs\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the model "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/home/jovyan/work'"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "rf_classifier.save(\"/home/jovyan/work/RF_model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.classification import RandomForestClassificationModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "rf=RandomForestClassificationModel.load(\"/home/jovyan/work/RF_model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_preditions=rf.transform(test_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+-------+--------------------+--------------------+----------+\n",
      "|            features|affairs|       rawPrediction|         probability|prediction|\n",
      "+--------------------+-------+--------------------+--------------------+----------+\n",
      "|[1.0,22.0,2.5,1.0...|      1|[188.676639360932...|[0.37735327872186...|       1.0|\n",
      "|[1.0,27.0,2.5,0.0...|      1|[195.425833792250...|[0.39085166758450...|       1.0|\n",
      "|[1.0,27.0,6.0,0.0...|      0|[193.138478579040...|[0.38627695715808...|       1.0|\n",
      "|[1.0,27.0,6.0,2.0...|      1|[185.424877645536...|[0.37084975529107...|       1.0|\n",
      "|[1.0,27.0,6.0,3.0...|      0|[164.685852316351...|[0.32937170463270...|       1.0|\n",
      "|[1.0,27.0,9.0,4.0...|      0|[142.006095001922...|[0.28401219000384...|       1.0|\n",
      "|[1.0,32.0,13.0,0....|      1|[176.885399312490...|[0.35377079862498...|       1.0|\n",
      "|[1.0,32.0,13.0,2....|      1|[128.585405941664...|[0.25717081188332...|       1.0|\n",
      "|[1.0,32.0,13.0,2....|      1|[126.963464206019...|[0.25392692841203...|       1.0|\n",
      "|[1.0,32.0,16.5,2....|      1|[153.429839787005...|[0.30685967957401...|       1.0|\n",
      "|[1.0,32.0,16.5,2....|      1|[153.429839787005...|[0.30685967957401...|       1.0|\n",
      "|[1.0,32.0,16.5,3....|      1|[167.493344562280...|[0.33498668912456...|       1.0|\n",
      "|[1.0,37.0,16.5,1....|      1|[150.090425556233...|[0.30018085111246...|       1.0|\n",
      "|[1.0,37.0,16.5,2....|      1|[154.326506925977...|[0.30865301385195...|       1.0|\n",
      "|[1.0,37.0,23.0,4....|      1|[123.559481897677...|[0.24711896379535...|       1.0|\n",
      "|[1.0,37.0,23.0,5....|      1|[192.317134975430...|[0.38463426995086...|       1.0|\n",
      "|[1.0,42.0,16.5,5....|      1|[151.350101320899...|[0.30270020264179...|       1.0|\n",
      "|[1.0,42.0,23.0,2....|      1|[126.937590310819...|[0.25387518062163...|       1.0|\n",
      "|[1.0,42.0,23.0,2....|      1|[156.647804216314...|[0.31329560843262...|       1.0|\n",
      "|[1.0,42.0,23.0,2....|      1|[156.647804216314...|[0.31329560843262...|       1.0|\n",
      "+--------------------+-------+--------------------+--------------------+----------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model_preditions.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
